#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import copy
import pathlib
import tempfile
import unittest

import torch
import torch.nn
import torch.optim
import torch.testing

import fairdiplomacy.agents  # Avoiding import cycles.
from fairdiplomacy import pydipcc
from fairdiplomacy.models.consts import POWERS, MAX_SEQ_LEN
from fairdiplomacy.selfplay.exploit import TrainerState, NetTrainingState
from fairdiplomacy.selfplay.search.data_loader import (
    flatten_dict,
    unflatten_dict,
    compress_and_flatten,
    decompress_and_unflatten,
)
from fairdiplomacy.selfplay.search.search_utils import (
    power_prob_distributions_to_tensors_independent,
    compute_search_policy_cross_entropy_sampled,
    compute_search_policy_entropy,
    create_research_targets_single_rollout,
    evs_to_policy,
)
from fairdiplomacy.selfplay import rela
from fairdiplomacy.utils.order_idxs import ORDER_VOCABULARY
from fairdiplomacy.utils.thread_pool_encoding import FeatureEncoder
import nest


class CompressDecompressTest(unittest.TestCase):
    """Round-trip tests for the flatten/compress helpers in data_loader."""

    def test_flatten_dict(self):
        nested = {"a": 1, "b": {"c": 2}}
        self.assertEqual(flatten_dict(nested), {"a": 1, "b/c": 2})

    def test_unflatten_dict(self):
        flat = {"a": 1, "b/c": 2}
        self.assertEqual(unflatten_dict(flat), {"a": 1, "b": {"c": 2}})

    def test_compress_and_flatten(self):
        game = pydipcc.Game()
        observations = FeatureEncoder().encode_inputs([game], input_version=1)

        packed = compress_and_flatten({"observations": observations})
        # Compression flattens nested keys and downcasts possible-action
        # indices to int16 for a smaller on-the-wire footprint.
        self.assertTrue("observations/x_possible_actions" in packed)
        self.assertEqual(packed["observations/x_possible_actions"].dtype, torch.short)

        restored = decompress_and_unflatten(packed)
        # The input tensor is left untouched; the restored one is int32 again.
        self.assertEqual(observations["x_possible_actions"].dtype, torch.int32)
        self.assertEqual(restored["observations"]["x_possible_actions"].dtype, torch.int32)


class BufferLoadSaveTest(unittest.TestCase):
    def test_load_save(self):
        """Saving a replay buffer to disk and loading it back preserves content.

        Fills a NestPrioritizedReplay with 10 small entries, round-trips it
        through a temporary file, and checks size, num_add semantics, and the
        stored tensors.
        """
        # NOTE: was `test_load_asve` — fixed typo in the method name.
        data = []
        buffer = rela.NestPrioritizedReplay(
            capacity=100, seed=0, alpha=1.0, beta=1.0, prefetch=False
        )
        for i in range(10):
            data.append({"a": torch.zeros(1) + i, "b": torch.zeros(1) + 2 + i})
            buffer.add_one(data[-1], 1)

        with tempfile.NamedTemporaryFile() as tmp_file:
            buffer.save(tmp_file.name)
            buffer2 = rela.NestPrioritizedReplay(
                capacity=100, seed=0, alpha=1.0, beta=1.0, prefetch=False
            )
            buffer2.load(tmp_file.name)

        self.assertEqual(buffer.size(), buffer2.size())
        # Load is not counted towards num_add.
        self.assertEqual(buffer2.num_add(), 0)
        read_data, _ = buffer2.get_all_content()

        self.assertEqual(data, read_data)


class TestSimplexDiscounting(unittest.TestCase):
    """Tests for create_research_targets_single_rollout reward discounting."""

    def _check_targets(self, episode_reward, expected, discount, alive=None):
        """Run the target computation with no exploration / zero predicted
        values and assert the result matches `expected`."""
        is_explore = torch.zeros(list(expected.shape), dtype=torch.bool)
        predicted_values = torch.zeros_like(expected)
        if alive is None:
            alive = torch.ones_like(predicted_values).to(torch.bool)
        torch.testing.assert_allclose(
            create_research_targets_single_rollout(
                is_explore, episode_reward, predicted_values, alive, discount
            ),
            expected,
        )

    def test_full_discounting(self):
        # Discount 0.0: every step's target collapses toward the final reward.
        episode_reward = torch.as_tensor([0.1, 0.9])
        expected = torch.as_tensor([[0.5, 0.5], [0.1, 0.9]])
        self._check_targets(episode_reward, expected, 0.0)

    def test_partial_discounting(self):
        # Discount 0.5: targets interpolate between uniform and final reward.
        episode_reward = torch.as_tensor([0.1, 0.9])
        expected = torch.as_tensor([[0.4, 0.6], [0.3, 0.7], [0.1, 0.9]])
        self._check_targets(episode_reward, expected, 0.5)

    def test_many_countries_partial_discounting(self):
        episode_reward = torch.as_tensor([0.0, 0.0, 0.1, 0.9])
        expected = torch.as_tensor(
            [[0.125, 0.125, 0.275, 0.7 - 0.45 / 2], [0, 0, 0.3, 0.7], [0, 0, 0.1, 0.9]]
        )
        alive = torch.ones_like(expected).to(torch.bool)
        # The first two powers are dead from the second step onward.
        alive[1:, :2] = 0
        self._check_targets(episode_reward, expected, 0.5, alive=alive)


class SearchPolicyTest(unittest.TestCase):
    """Tests for search-policy cross-entropy / entropy helpers."""

    def _make_hold_policy_tensors(self, game, observations):
        """Build (orders, probs) tensors for a policy where every power holds.

        Returns tensors of shape [7, 2, MAX_SEQ_LEN] and [7, 2]: for each
        power, sample 0 is the all-holds action with probability ~1 and
        sample 1 is padding with probability ~0. Also sanity-checks that the
        encoded order indices decode back to the all-holds action.
        """
        # Every power's policy is to do all holds (the first possible order
        # per orderable location).
        all_power_prob_distributions = {}
        for power, locs in game.get_orderable_locations().items():
            action = tuple(game.get_all_possible_orders()[loc][0] for loc in locs)
            all_power_prob_distributions[power] = {action: 1.0}

        orders_tensor, probs_tensor = power_prob_distributions_to_tensors_independent(
            all_power_prob_distributions,
            2,
            observations["x_possible_actions"].squeeze(0),
            observations["x_in_adj_phase"].item(),
        )

        self.assertEqual(list(orders_tensor.shape), [7, 2, MAX_SEQ_LEN])
        self.assertEqual(list(probs_tensor.shape), [7, 2])
        # Sample 0 has probability ~1, sample 1 ~0.
        self.assertTrue((probs_tensor[:, 0] - 1 < 1e-4).all())
        self.assertTrue((probs_tensor[:, 1] < 1e-4).all())

        # Decoding the per-step indices through x_possible_actions must
        # reproduce the first power's hold action.
        self.assertEqual(
            set(
                ORDER_VOCABULARY[observations["x_possible_actions"].squeeze(0)[0, i, x]]
                for i, x in enumerate(orders_tensor[0, 0].tolist())
                if x != -1
            ),
            set(list(all_power_prob_distributions[POWERS[0]])[0]),
        )
        return orders_tensor, probs_tensor

    def test_xe_loss_initial_state(self):
        game = pydipcc.Game()
        observations = FeatureEncoder().encode_inputs([game], input_version=1)
        orders_tensor, probs_tensor = self._make_hold_policy_tensors(game, observations)

        def fake_model(*, teacher_force_orders, **kwargs):
            # Uniform (all-zero) logits over the 453-entry order vocabulary slice.
            logits = torch.zeros(list(teacher_force_orders.shape) + [453])
            return None, None, logits, None

        fake_model.is_all_powers = lambda: False

        # Smoke test with T=1, B=1 (adding extra leading dimensions).
        compute_search_policy_cross_entropy_sampled(
            fake_model,  # type: ignore
            nest.map(lambda x: x.unsqueeze(0), observations),
            orders_tensor.unsqueeze(0).unsqueeze(0),
            probs_tensor.unsqueeze(0).unsqueeze(0),
        )

        def fake_time_batch(tensor):
            # Tile the single example to T=5, B=10 without copying memory.
            shape = list(tensor.shape)
            shape[0] = 5
            shape[1] = 10
            return tensor.expand(shape)

        loss, _ = compute_search_policy_cross_entropy_sampled(
            fake_model,  # type: ignore
            nest.map(lambda x: fake_time_batch(x.unsqueeze(0)), observations),
            fake_time_batch(orders_tensor.unsqueeze(0).unsqueeze(0)),
            fake_time_batch(probs_tensor.unsqueeze(0).unsqueeze(0)),
        )

        npowers = len(POWERS)

        # Using an all-true mask should not change the result.
        loss_full_mask, _ = compute_search_policy_cross_entropy_sampled(
            fake_model,  # type: ignore
            nest.map(lambda x: fake_time_batch(x.unsqueeze(0)), observations),
            fake_time_batch(orders_tensor.unsqueeze(0).unsqueeze(0)),
            fake_time_batch(probs_tensor.unsqueeze(0).unsqueeze(0)),
            mask=torch.ones((5, 10, npowers), dtype=torch.bool),
        )
        self.assertEqual(loss.item(), loss_full_mask.item())

        # As every row is a repeat, keeping only the first row must not
        # change the loss either.
        mask_first_row = torch.ones((5, 10, npowers), dtype=torch.bool)
        mask_first_row[1:] = 0
        loss_mask_some, _ = compute_search_policy_cross_entropy_sampled(
            fake_model,  # type: ignore
            nest.map(lambda x: fake_time_batch(x.unsqueeze(0)), observations),
            fake_time_batch(orders_tensor.unsqueeze(0).unsqueeze(0)),
            fake_time_batch(probs_tensor.unsqueeze(0).unsqueeze(0)),
            mask=mask_first_row,
        )
        # BUGFIX: previously this re-asserted loss == loss_full_mask and never
        # checked loss_mask_some at all.
        self.assertEqual(loss.item(), loss_mask_some.item())

    def test_entropy(self):
        game = pydipcc.Game()
        observations = FeatureEncoder().encode_inputs([game], input_version=1)
        orders_tensor, probs_tensor = self._make_hold_policy_tensors(game, observations)

        # Smoke test: entropy computation runs on a T=1, B=1 batch.
        compute_search_policy_entropy(
            orders_tensor.unsqueeze(0).unsqueeze(0), probs_tensor.unsqueeze(0).unsqueeze(0)
        )

    def test_evs_to_policy(self):
        """Smoke test for EV-to-policy conversion in both modes."""
        torch.manual_seed(0)
        evs = torch.rand((2, 3, 7, 10))
        # Presumably -1 marks invalid/missing EV entries — verify against
        # evs_to_policy if this test is extended.
        evs[0][0] = -1.0
        evs[evs > 0.9] = -1.0

        evs_to_policy(evs, use_softmax=True)
        evs_to_policy(evs, use_softmax=False)


class TrainerStateTest(unittest.TestCase):
    """Tests for TrainerState / NetTrainingState counters and (de)serialization."""

    def _create_state(self):
        """Build a minimal TrainerState around a tiny linear net."""
        net = torch.nn.Linear(10, 20)
        optimizer = torch.optim.SGD(net.parameters(), 1e-3)

        net_state = NetTrainingState(
            model=net, optimizer=optimizer, scheduler=None, args=argparse.Namespace()
        )
        trainer_state = TrainerState(net_state=net_state, value_net_state=None)
        trainer_state.epoch_id = 12
        trainer_state.global_step = 13
        return trainer_state

    def test_epoch_step_propagation(self):
        """Counters set on TrainerState propagate to the nested net states."""
        # NOTE: was `test_epoch_step_propogation` — fixed typo in the name.
        state = self._create_state()
        self.assertEqual(state.epoch_id, state.net_state.epoch_id)
        self.assertEqual(state.global_step, state.net_state.global_step)
        state.epoch_id += 1
        state.global_step += 2
        self.assertEqual(state.epoch_id, state.net_state.epoch_id)
        self.assertEqual(state.global_step, state.net_state.global_step)
        # Propagation must also reach a value net attached after construction.
        state.value_net_state = copy.deepcopy(state.net_state)
        state.epoch_id += 1
        state.global_step += 2
        self.assertEqual(state.epoch_id, state.net_state.epoch_id)
        self.assertEqual(state.global_step, state.net_state.global_step)
        self.assertEqual(state.epoch_id, state.value_net_state.epoch_id)
        self.assertEqual(state.global_step, state.value_net_state.global_step)

    def test_save_load_net_state(self):
        """NetTrainingState.save round-trips model weights via from_dict."""
        net_state = self._create_state().net_state
        with tempfile.TemporaryDirectory() as d:
            path = pathlib.Path(f"{d}/state")
            net_state.save(path)
            loaded_state = net_state.from_dict(torch.load(path), net_state)

        self.assertTrue((net_state.model.weight == loaded_state.model.weight).all())  # type:ignore

    def test_save_load_trainer_state(self):
        """TrainerState.save round-trips counters via from_dict."""
        state = self._create_state()
        with tempfile.TemporaryDirectory() as d:
            path = pathlib.Path(f"{d}/state2")
            state.save(path)
            loaded_state = state.from_dict(torch.load(path), state)

        self.assertEqual(state.global_step, loaded_state.global_step)
        self.assertEqual(state.epoch_id, loaded_state.epoch_id)
        self.assertEqual(state.global_step, loaded_state.net_state.global_step)
        self.assertEqual(state.epoch_id, loaded_state.net_state.epoch_id)

    def test_save_load_trainer_state_file(self):
        """TrainerState.load reads a saved state file directly."""
        state = self._create_state()
        with tempfile.TemporaryDirectory() as d:
            path = pathlib.Path(f"{d}/state2")
            state.save(path)
            loaded_state = state.load(path, state)

        self.assertEqual(state.global_step, loaded_state.global_step)
        self.assertEqual(state.epoch_id, loaded_state.epoch_id)
        self.assertEqual(state.global_step, loaded_state.net_state.global_step)
        self.assertEqual(state.epoch_id, loaded_state.net_state.epoch_id)
